"""Storage for auth models."""
import asyncio
from collections import OrderedDict
from datetime import timedelta
import hmac
from logging import getLogger
from typing import Any, Dict, List, Optional

from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util

from . import models
from .const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY, GROUP_ID_USER
from .permissions import PermissionLookup, system_policies
from .permissions.types import PolicyType

STORAGE_VERSION = 1
STORAGE_KEY = "auth"
GROUP_NAME_ADMIN = "Administrators"
GROUP_NAME_USER = "Users"
GROUP_NAME_READ_ONLY = "Read Only"


class AuthStore:
    """Stores authentication info.

    Any mutation to an object should happen inside the auth store.

    The auth store is lazy. It won't load the data from disk until a method is
    called that needs it.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the auth store."""
        self.hass = hass
        self._users: Optional[Dict[str, models.User]] = None
        self._groups: Optional[Dict[str, models.Group]] = None
        self._perm_lookup: Optional[PermissionLookup] = None
        self._store = hass.helpers.storage.Store(
            STORAGE_VERSION, STORAGE_KEY, private=True
        )
        self._lock = asyncio.Lock()

    async def async_get_groups(self) -> List[models.Group]:
        """Retrieve all users."""
        if self._groups is None:
            await self._async_load()
            assert self._groups is not None

        return list(self._groups.values())

    async def async_get_group(self, group_id: str) -> Optional[models.Group]:
        """Retrieve all users."""
        if self._groups is None:
            await self._async_load()
            assert self._groups is not None

        return self._groups.get(group_id)

    async def async_get_users(self) -> List[models.User]:
        """Retrieve all users."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        return list(self._users.values())

    async def async_get_user(self, user_id: str) -> Optional[models.User]:
        """Retrieve a user by id."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        return self._users.get(user_id)

    async def async_create_user(
        self,
        name: Optional[str],
        is_owner: Optional[bool] = None,
        is_active: Optional[bool] = None,
        system_generated: Optional[bool] = None,
        credentials: Optional[models.Credentials] = None,
        group_ids: Optional[List[str]] = None,
    ) -> models.User:
        """Create a new user."""
        if self._users is None:
            await self._async_load()

        assert self._users is not None
        assert self._groups is not None

        groups = []
        for group_id in group_ids or []:
            group = self._groups.get(group_id)
            if group is None:
                raise ValueError(f"Invalid group specified {group_id}")
            groups.append(group)

        kwargs: Dict[str, Any] = {
            "name": name,
            # Until we get group management, we just put everyone in the
            # same group.
            "groups": groups,
            "perm_lookup": self._perm_lookup,
        }

        if is_owner is not None:
            kwargs["is_owner"] = is_owner

        if is_active is not None:
            kwargs["is_active"] = is_active

        if system_generated is not None:
            kwargs["system_generated"] = system_generated

        new_user = models.User(**kwargs)

        self._users[new_user.id] = new_user

        if credentials is None:
            self._async_schedule_save()
            return new_user

        # Saving is done inside the link.
        await self.async_link_user(new_user, credentials)
        return new_user

    async def async_link_user(
        self, user: models.User, credentials: models.Credentials
    ) -> None:
        """Add credentials to an existing user."""
        user.credentials.append(credentials)
        self._async_schedule_save()
        credentials.is_new = False

    async def async_remove_user(self, user: models.User) -> None:
        """Remove a user."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        self._users.pop(user.id)
        self._async_schedule_save()

    async def async_update_user(
        self,
        user: models.User,
        name: Optional[str] = None,
        is_active: Optional[bool] = None,
        group_ids: Optional[List[str]] = None,
    ) -> None:
        """Update a user."""
        assert self._groups is not None

        if group_ids is not None:
            groups = []
            for grid in group_ids:
                group = self._groups.get(grid)
                if group is None:
                    raise ValueError("Invalid group specified.")
                groups.append(group)

            user.groups = groups
            user.invalidate_permission_cache()

        for attr_name, value in (("name", name), ("is_active", is_active)):
            if value is not None:
                setattr(user, attr_name, value)

        self._async_schedule_save()

    async def async_activate_user(self, user: models.User) -> None:
        """Activate a user."""
        user.is_active = True
        self._async_schedule_save()

    async def async_deactivate_user(self, user: models.User) -> None:
        """Activate a user."""
        user.is_active = False
        self._async_schedule_save()

    async def async_remove_credentials(self, credentials: models.Credentials) -> None:
        """Remove credentials."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            found = None

            for index, cred in enumerate(user.credentials):
                if cred is credentials:
                    found = index
                    break

            if found is not None:
                user.credentials.pop(found)
                break

        self._async_schedule_save()

    async def async_create_refresh_token(
        self,
        user: models.User,
        client_id: Optional[str] = None,
        client_name: Optional[str] = None,
        client_icon: Optional[str] = None,
        token_type: str = models.TOKEN_TYPE_NORMAL,
        access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
    ) -> models.RefreshToken:
        """Create a new token for a user."""
        kwargs: Dict[str, Any] = {
            "user": user,
            "client_id": client_id,
            "token_type": token_type,
            "access_token_expiration": access_token_expiration,
        }
        if client_name:
            kwargs["client_name"] = client_name
        if client_icon:
            kwargs["client_icon"] = client_icon

        refresh_token = models.RefreshToken(**kwargs)
        user.refresh_tokens[refresh_token.id] = refresh_token

        self._async_schedule_save()
        return refresh_token

    async def async_remove_refresh_token(
        self, refresh_token: models.RefreshToken
    ) -> None:
        """Remove a refresh token."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            if user.refresh_tokens.pop(refresh_token.id, None):
                self._async_schedule_save()
                break

    async def async_get_refresh_token(
        self, token_id: str
    ) -> Optional[models.RefreshToken]:
        """Get refresh token by id."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            refresh_token = user.refresh_tokens.get(token_id)
            if refresh_token is not None:
                return refresh_token

        return None

    async def async_get_refresh_token_by_token(
        self, token: str
    ) -> Optional[models.RefreshToken]:
        """Get refresh token by token."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        found = None

        for user in self._users.values():
            for refresh_token in user.refresh_tokens.values():
                if hmac.compare_digest(refresh_token.token, token):
                    found = refresh_token

        return found

    @callback
    def async_log_refresh_token_usage(
        self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
    ) -> None:
        """Update refresh token last used information."""
        refresh_token.last_used_at = dt_util.utcnow()
        refresh_token.last_used_ip = remote_ip
        self._async_schedule_save()

    async def _async_load(self) -> None:
        """Load the users."""
        async with self._lock:
            if self._users is not None:
                return
            await self._async_load_task()

    async def _async_load_task(self) -> None:
        """Load the users."""
        [ent_reg, dev_reg, data] = await asyncio.gather(
            self.hass.helpers.entity_registry.async_get_registry(),
            self.hass.helpers.device_registry.async_get_registry(),
            self._store.async_load(),
        )

        # Make sure that we're not overriding data if 2 loads happened at the
        # same time
        if self._users is not None:
            return

        self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)

        if data is None:
            self._set_defaults()
            return

        users: Dict[str, models.User] = OrderedDict()
        groups: Dict[str, models.Group] = OrderedDict()

        # Soft-migrating data as we load. We are going to make sure we have a
        # read only group and an admin group. There are two states that we can
        # migrate from:
        # 1. Data from a recent version which has a single group without policy
        # 2. Data from old version which has no groups
        has_admin_group = False
        has_user_group = False
        has_read_only_group = False
        group_without_policy = None

        # When creating objects we mention each attribute explicitly. This
        # prevents crashing if user rolls back HA version after a new property
        # was added.

        for group_dict in data.get("groups", []):
            policy: Optional[PolicyType] = None

            if group_dict["id"] == GROUP_ID_ADMIN:
                has_admin_group = True

                name = GROUP_NAME_ADMIN
                policy = system_policies.ADMIN_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_USER:
                has_user_group = True

                name = GROUP_NAME_USER
                policy = system_policies.USER_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_READ_ONLY:
                has_read_only_group = True

                name = GROUP_NAME_READ_ONLY
                policy = system_policies.READ_ONLY_POLICY
                system_generated = True

            else:
                name = group_dict["name"]
                policy = group_dict.get("policy")
                system_generated = False

            # We don't want groups without a policy that are not system groups
            # This is part of migrating from state 1
            if policy is None:
                group_without_policy = group_dict["id"]
                continue

            groups[group_dict["id"]] = models.Group(
                id=group_dict["id"],
                name=name,
                policy=policy,
                system_generated=system_generated,
            )

        # If there are no groups, add all existing users to the admin group.
        # This is part of migrating from state 2
        migrate_users_to_admin_group = not groups and group_without_policy is None

        # If we find a no_policy_group, we need to migrate all users to the
        # admin group. We only do this if there are no other groups, as is
        # the expected state. If not expected state, not marking people admin.
        # This is part of migrating from state 1
        if groups and group_without_policy is not None:
            group_without_policy = None

        # This is part of migrating from state 1 and 2
        if not has_admin_group:
            admin_group = _system_admin_group()
            groups[admin_group.id] = admin_group

        # This is part of migrating from state 1 and 2
        if not has_read_only_group:
            read_only_group = _system_read_only_group()
            groups[read_only_group.id] = read_only_group

        if not has_user_group:
            user_group = _system_user_group()
            groups[user_group.id] = user_group

        for user_dict in data["users"]:
            # Collect the users group.
            user_groups = []
            for group_id in user_dict.get("group_ids", []):
                # This is part of migrating from state 1
                if group_id == group_without_policy:
                    group_id = GROUP_ID_ADMIN
                user_groups.append(groups[group_id])

            # This is part of migrating from state 2
            if not user_dict["system_generated"] and migrate_users_to_admin_group:
                user_groups.append(groups[GROUP_ID_ADMIN])

            users[user_dict["id"]] = models.User(
                name=user_dict["name"],
                groups=user_groups,
                id=user_dict["id"],
                is_owner=user_dict["is_owner"],
                is_active=user_dict["is_active"],
                system_generated=user_dict["system_generated"],
                perm_lookup=perm_lookup,
            )

        for cred_dict in data["credentials"]:
            users[cred_dict["user_id"]].credentials.append(
                models.Credentials(
                    id=cred_dict["id"],
                    is_new=False,
                    auth_provider_type=cred_dict["auth_provider_type"],
                    auth_provider_id=cred_dict["auth_provider_id"],
                    data=cred_dict["data"],
                )
            )

        for rt_dict in data["refresh_tokens"]:
            # Filter out the old keys that don't have jwt_key (pre-0.76)
            if "jwt_key" not in rt_dict:
                continue

            created_at = dt_util.parse_datetime(rt_dict["created_at"])
            if created_at is None:
                getLogger(__name__).error(
                    "Ignoring refresh token %(id)s with invalid created_at "
                    "%(created_at)s for user_id %(user_id)s",
                    rt_dict,
                )
                continue

            token_type = rt_dict.get("token_type")
            if token_type is None:
                if rt_dict["client_id"] is None:
                    token_type = models.TOKEN_TYPE_SYSTEM
                else:
                    token_type = models.TOKEN_TYPE_NORMAL

            # old refresh_token don't have last_used_at (pre-0.78)
            last_used_at_str = rt_dict.get("last_used_at")
            if last_used_at_str:
                last_used_at = dt_util.parse_datetime(last_used_at_str)
            else:
                last_used_at = None

            token = models.RefreshToken(
                id=rt_dict["id"],
                user=users[rt_dict["user_id"]],
                client_id=rt_dict["client_id"],
                # use dict.get to keep backward compatibility
                client_name=rt_dict.get("client_name"),
                client_icon=rt_dict.get("client_icon"),
                token_type=token_type,
                created_at=created_at,
                access_token_expiration=timedelta(
                    seconds=rt_dict["access_token_expiration"]
                ),
                token=rt_dict["token"],
                jwt_key=rt_dict["jwt_key"],
                last_used_at=last_used_at,
                last_used_ip=rt_dict.get("last_used_ip"),
            )
            users[rt_dict["user_id"]].refresh_tokens[token.id] = token

        self._groups = groups
        self._users = users

    @callback
    def _async_schedule_save(self) -> None:
        """Save users."""
        if self._users is None:
            return

        self._store.async_delay_save(self._data_to_save, 1)

    @callback
    def _data_to_save(self) -> Dict:
        """Return the data to store."""
        assert self._users is not None
        assert self._groups is not None

        users = [
            {
                "id": user.id,
                "group_ids": [group.id for group in user.groups],
                "is_owner": user.is_owner,
                "is_active": user.is_active,
                "name": user.name,
                "system_generated": user.system_generated,
            }
            for user in self._users.values()
        ]

        groups = []
        for group in self._groups.values():
            g_dict: Dict[str, Any] = {
                "id": group.id,
                # Name not read for sys groups. Kept here for backwards compat
                "name": group.name,
            }

            if not group.system_generated:
                g_dict["policy"] = group.policy

            groups.append(g_dict)

        credentials = [
            {
                "id": credential.id,
                "user_id": user.id,
                "auth_provider_type": credential.auth_provider_type,
                "auth_provider_id": credential.auth_provider_id,
                "data": credential.data,
            }
            for user in self._users.values()
            for credential in user.credentials
        ]

        refresh_tokens = [
            {
                "id": refresh_token.id,
                "user_id": user.id,
                "client_id": refresh_token.client_id,
                "client_name": refresh_token.client_name,
                "client_icon": refresh_token.client_icon,
                "token_type": refresh_token.token_type,
                "created_at": refresh_token.created_at.isoformat(),
                "access_token_expiration": refresh_token.access_token_expiration.total_seconds(),
                "token": refresh_token.token,
                "jwt_key": refresh_token.jwt_key,
                "last_used_at": refresh_token.last_used_at.isoformat()
                if refresh_token.last_used_at
                else None,
                "last_used_ip": refresh_token.last_used_ip,
            }
            for user in self._users.values()
            for refresh_token in user.refresh_tokens.values()
        ]

        return {
            "users": users,
            "groups": groups,
            "credentials": credentials,
            "refresh_tokens": refresh_tokens,
        }

    def _set_defaults(self) -> None:
        """Set default values for auth store."""
        self._users = OrderedDict()

        groups: Dict[str, models.Group] = OrderedDict()
        admin_group = _system_admin_group()
        groups[admin_group.id] = admin_group
        user_group = _system_user_group()
        groups[user_group.id] = user_group
        read_only_group = _system_read_only_group()
        groups[read_only_group.id] = read_only_group
        self._groups = groups


def _system_admin_group() -> models.Group:
    """Create system admin group."""
    return models.Group(
        name=GROUP_NAME_ADMIN,
        id=GROUP_ID_ADMIN,
        policy=system_policies.ADMIN_POLICY,
        system_generated=True,
    )


def _system_user_group() -> models.Group:
    """Create system user group."""
    return models.Group(
        name=GROUP_NAME_USER,
        id=GROUP_ID_USER,
        policy=system_policies.USER_POLICY,
        system_generated=True,
    )


def _system_read_only_group() -> models.Group:
    """Create read only group."""
    return models.Group(
        name=GROUP_NAME_READ_ONLY,
        id=GROUP_ID_READ_ONLY,
        policy=system_policies.READ_ONLY_POLICY,
        system_generated=True,
    )
